import torch
import torch.nn.functional as F
from codebase import utils as ut
from torch import nn
from torch.nn import functional as F
from torch.nn import Linear
# Select the compute device once, at import time. GPU index 2 is preferred,
# but only when at least three GPUs are visible; with fewer GPUs fall back to
# the first one so the exported device always exists. Without CUDA use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:2" if torch.cuda.device_count() > 2 else "cuda:0")
else:
    device = torch.device("cpu")

class ConvEncoder(nn.Module):
    """Convolutional encoder with two linear output heads.

    The feature extractor is four ``Conv2d(kernel_size=3) -> MaxPool2d(2) ->
    SELU`` stages (3 -> 8 -> 16 -> 32 -> 64 channels) followed by a flatten.
    Both heads are ``Linear(256, z_dim)``, so the input must flatten to
    exactly 256 features; this is the case for, e.g., 3x64x64 images
    (64 channels x 2 x 2 after the last pooling step).

    Args:
        z_dim: Output size of each head.
        u_dim: Stored as ``self.u_dim``; not used by any layer built here.
        out_dim: Accepted but unused.
    """

    def __init__(self, z_dim=8, u_dim=4, out_dim=None):
        super().__init__()
        self.z_dim = z_dim
        self.u_dim = u_dim

        # Build the conv stack stage by stage. The resulting Sequential has
        # the same module indices (0..12) as a hand-written listing, so
        # state_dict keys such as ``conv1.3.weight`` are unchanged.
        channels = (3, 8, 16, 32, 64)
        stages = []
        for c_in, c_out in zip(channels, channels[1:]):
            stages.extend((nn.Conv2d(c_in, c_out, 3), nn.MaxPool2d(2), nn.SELU()))
        stages.append(nn.Flatten(1))
        self.conv1 = nn.Sequential(*stages)

        # Two independent heads over the same 256 flattened features.
        self.mean_layer = nn.Sequential(nn.Linear(256, self.z_dim))
        self.var_layer = nn.Sequential(nn.Linear(256, self.z_dim))

    def encode(self, x, u=None):
        """Run the conv stack on ``x`` and apply both heads.

        Args:
            x: Image batch of shape ``(N, 3, H, W)`` whose conv features
                flatten to 256 values per sample.
            u: Accepted but ignored.

        Returns:
            Tuple ``(mu, logvar)``: the outputs of ``mean_layer`` and
            ``var_layer``, each of shape ``(N, z_dim)``.
        """
        features = self.conv1(x)
        mu = self.mean_layer(features)
        logvar = self.var_layer(features)
        return mu, logvar

    def forward(self, x, u=None):
        """Alias for :meth:`encode` so the module can be called directly."""
        return self.encode(x, u)

class ConvDecoder(nn.Module):
    """Decoder mapping a latent vector to a flattened, Tanh-squashed output.

    Pipeline: ``Linear(z_dim, 256) -> SELU -> Linear(256, 16*12*12) -> SELU``,
    reshape to ``(N, 16, 12, 12)``, three ``Conv2d(kernel_size=2) -> SELU``
    stages (16 -> 32 -> 64 -> 128 channels, 12 -> 11 -> 10 -> 9 spatial),
    flatten, ``Linear(128*9*9, 3*64*64)`` and ``Tanh``.

    Args:
        z_dim: Size of the latent input.
        out_dim: Accepted but unused.
    """

    def __init__(self, z_dim, out_dim=None):
        super().__init__()
        self.z_dim = z_dim
        # Attribute names net1..net5 are kept because they are the
        # state_dict key prefixes.
        self.net1 = nn.Linear(z_dim, 256)
        self.net2 = nn.SELU()
        self.net3 = nn.Linear(256, 16 * 12 * 12)
        self.net4 = nn.SELU()
        self.net5 = nn.Sequential(
            nn.Conv2d(16, 32, 2),
            nn.SELU(),
            nn.Conv2d(32, 64, 2),
            nn.SELU(),
            nn.Conv2d(64, 128, 2),
            nn.SELU(),
            nn.Flatten(1),
            nn.Linear(128 * 9 * 9, 3 * 64 * 64),
            nn.Tanh(),
        )

    def decode(self, z):
        """Decode latent codes.

        Args:
            z: Tensor whose last dimension is ``z_dim``.

        Returns:
            Tensor of shape ``(N, 3*64*64)`` with values in ``[-1, 1]``
            (the output is flat; it is not reshaped to an image here).
        """
        hidden = self.net2(self.net1(z))
        hidden = self.net4(self.net3(hidden))
        # Turn the 2304 features into a 16-channel 12x12 map for the convs.
        feature_map = hidden.view(-1, 16, 12, 12)
        return self.net5(feature_map)

    def forward(self, z):
        """Alias for :meth:`decode` so the module can be called directly."""
        return self.decode(z)

